{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6f938540",
   "metadata": {},
   "source": [
    "# An end-to-end Vertex Batch Prediction Pipeline Demonstration"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62f7de01",
   "metadata": {},
   "source": [
    "First, check that the required packages are installed correctly. The KFP SDK version should be >=1.6:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9222df33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KFP SDK version: 1.8.2\r\n"
     ]
    }
   ],
   "source": [
    "!python3 -c \"import kfp; print('KFP SDK version: {}'.format(kfp.__version__))\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "164ff1c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from functools import partial\n",
    "\n",
    "import kfp\n",
    "import pprint\n",
    "import yaml\n",
    "from jinja2 import Template\n",
    "from kfp.v2 import dsl\n",
    "from kfp.v2.compiler import compiler\n",
    "from kfp.v2.dsl import Dataset\n",
    "from kfp.v2.google.client import AIPlatformClient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dbe612cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "project_id='woven-rush-197905'\n",
    "project_number='297370817971'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "27224ab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "af_registry_location='asia-southeast1'\n",
    "af_registry_name='mlops-vertex-kit'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cfe4d02b",
   "metadata": {},
   "outputs": [],
   "source": [
    "components_dir='../components/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "82030efb",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def _load_custom_component(project_id: str,\n",
    "                           af_registry_location: str,\n",
    "                           af_registry_name: str,\n",
    "                           components_dir: str,\n",
    "                           component_name: str):\n",
    "  \"\"\"Render a component's Jinja template and load it as a KFP component.\n",
    "\n",
    "  Reads `component.yaml.jinja` from `components_dir/component_name`, renders\n",
    "  it with `project_id`, `af_registry_location` and `af_registry_name` as\n",
    "  template variables, and returns the result of\n",
    "  `kfp.components.load_component_from_text` on the rendered text.\n",
    "  \"\"\"\n",
    "  template_path = os.path.join(\n",
    "      components_dir, component_name, 'component.yaml.jinja')\n",
    "  with open(template_path, 'r') as template_file:\n",
    "    template = Template(template_file.read())\n",
    "\n",
    "  rendered_spec = template.render(\n",
    "      project_id=project_id,\n",
    "      af_registry_location=af_registry_location,\n",
    "      af_registry_name=af_registry_name)\n",
    "  return kfp.components.load_component_from_text(rendered_spec)\n",
    "\n",
    "# Pre-bind the notebook-level settings; callers only pass `component_name`.\n",
    "load_custom_component = partial(\n",
    "    _load_custom_component,\n",
    "    project_id=project_id,\n",
    "    af_registry_location=af_registry_location,\n",
    "    af_registry_name=af_registry_name,\n",
    "    components_dir=components_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "495502ee",
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load the two custom components, one call per component directory name.\n",
    "preprocess_op, batch_prediction_op = (\n",
    "    load_custom_component(component_name=name)\n",
    "    for name in ('data_preprocess', 'batch_prediction'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0026cb75",
   "metadata": {},
   "source": [
    "Next, set the pipeline and data configuration, then define the pipeline with the following function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8d36a16a",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline_region='asia-southeast1'\n",
    "pipeline_root='gs://vertex_pipeline_demo_root/pipeline_root'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "954a7af2",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_region='asia-southeast1'\n",
    "input_dataset_uri='bq://woven-rush-197905.vertex_pipeline_demo.banknote_authentication_features'\n",
    "gcs_data_output_folder='gs://vertex_pipeline_demo_root/datasets/prediction'\n",
    "gcs_result_folder='gs://vertex_pipeline_demo_root/prediction'\n",
    "\n",
    "data_pipeline_root='gs://vertex_pipeline_demo_root/compute_root'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a36e0154",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dsl.pipeline(name='batch-prediction-pipeline-template')\n",
    "def pipeline(project_id: str,\n",
    "             data_region: str,\n",
    "             gcs_data_output_folder: str,\n",
    "             input_dataset_uri: str,\n",
    "             data_pipeline_root: str,\n",
    "             gcs_result_folder: str,\n",
    "             model_resource_name: str = '',\n",
    "             endpoint_resource_name: str = '',\n",
    "             machine_type: str = \"n1-standard-8\",\n",
    "             accelerator_count: int = 0,\n",
    "             accelerator_type: str = 'ACCELERATOR_TYPE_UNSPECIFIED',\n",
    "             starting_replica_count: int = 1,\n",
    "             max_replica_count: int = 2):\n",
    "    \"\"\"Preprocess a dataset and run batch prediction on the result.\n",
    "\n",
    "    The dataset at `input_dataset_uri` is imported as a `Dataset` artifact\n",
    "    and passed to `preprocess_op`, with `gcs_data_output_folder` as its\n",
    "    output folder and NEWLINE_DELIMITED_JSON as its output format. The\n",
    "    `output_dataset` of that task is then given to `batch_prediction_op`\n",
    "    with 'jsonl' as both the instances and the predictions format.\n",
    "\n",
    "    `model_resource_name`, `endpoint_resource_name`, the machine and\n",
    "    accelerator settings and the replica counts are forwarded unchanged to\n",
    "    `batch_prediction_op`.\n",
    "    \"\"\"\n",
    "    # Use the v2 `dsl.importer`: the `kfp.dsl.importer` alias emits a\n",
    "    # FutureWarning and is slated for removal in KFP v2.0.\n",
    "    dataset_importer = dsl.importer(\n",
    "        artifact_uri=input_dataset_uri,\n",
    "        artifact_class=Dataset,\n",
    "        reimport=False)\n",
    "\n",
    "    preprocess_task = preprocess_op(\n",
    "        project_id=project_id,\n",
    "        data_region=data_region,\n",
    "        gcs_output_folder=gcs_data_output_folder,\n",
    "        gcs_output_format=\"NEWLINE_DELIMITED_JSON\",\n",
    "        input_dataset=dataset_importer.output)\n",
    "\n",
    "    batch_prediction_op(\n",
    "        project_id=project_id,\n",
    "        data_region=data_region,\n",
    "        data_pipeline_root=data_pipeline_root,\n",
    "        gcs_result_folder=gcs_result_folder,\n",
    "        instances_format='jsonl',\n",
    "        predictions_format='jsonl',\n",
    "        input_dataset=preprocess_task.outputs['output_dataset'],\n",
    "        model_resource_name=model_resource_name,\n",
    "        endpoint_resource_name=endpoint_resource_name,\n",
    "        machine_type=machine_type,\n",
    "        accelerator_type=accelerator_type,\n",
    "        accelerator_count=accelerator_count,\n",
    "        starting_replica_count=starting_replica_count,\n",
    "        max_replica_count=max_replica_count)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff5fba56",
   "metadata": {},
   "source": [
    "### Compile and run the end-to-end ML pipeline\n",
    "With our full pipeline defined, it's time to compile it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "794c18bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/luoshixin/Develop/uob-mlops/venv/lib/python3.7/site-packages/kfp/dsl/__init__.py:32: FutureWarning: `kfp.dsl.importer` is a deprecated alias and will be removed in KFP v2.0. Please import from `kfp.v2.dsl` instead.\n",
      "  category=FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "# Compile the pipeline function into `batch_prediction_pipeline_job.json`.\n",
    "compiler.Compiler().compile(\n",
    "    pipeline_func=pipeline,\n",
    "    package_path=\"batch_prediction_pipeline_job.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "575b575f",
   "metadata": {},
   "source": [
    "Next, instantiate an API client:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5ab349ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/luoshixin/Develop/uob-mlops/venv/lib/python3.7/site-packages/kfp/v2/google/client/client.py:173: FutureWarning: AIPlatformClient will be deprecated in v1.9. Please use PipelineJob https://googleapis.dev/python/aiplatform/latest/_modules/google/cloud/aiplatform/pipeline_jobs.html in Vertex SDK. Install the SDK using \"pip install google-cloud-aiplatform\"\n",
      "  category=FutureWarning,\n"
     ]
    }
   ],
   "source": [
    "# NOTE: kfp warns that AIPlatformClient will be deprecated in v1.9 in favour\n",
    "# of PipelineJob from the Vertex SDK (google-cloud-aiplatform).\n",
    "api_client = AIPlatformClient(project_id=project_id, region=pipeline_region)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "951f42d8",
   "metadata": {},
   "source": [
    "Next, kick off a pipeline run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5a3fe7dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "See the Pipeline job <a href=\"https://console.cloud.google.com/vertex-ai/locations/asia-southeast1/pipelines/runs/batch-prediction-pipeline-template-20211008145708?project=woven-rush-197905\" target=\"_blank\" >here</a>."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Parameter values for this run; keys match the `pipeline` function arguments.\n",
    "pipeline_params = {\n",
    "    'project_id': project_id,\n",
    "    'data_region': data_region,\n",
    "    'gcs_data_output_folder': gcs_data_output_folder,\n",
    "    'gcs_result_folder': gcs_result_folder,\n",
    "    'input_dataset_uri': input_dataset_uri,\n",
    "    'data_pipeline_root': data_pipeline_root,\n",
    "    # Built from `project_number` so the endpoint follows the project set in\n",
    "    # the configuration cell instead of repeating the number as a literal.\n",
    "    'endpoint_resource_name': (\n",
    "        f'projects/{project_number}/locations/asia-southeast1/'\n",
    "        'endpoints/8843521555783745536'),\n",
    "}\n",
    "\n",
    "# Submit the compiled job spec; the returned response is kept in `response`.\n",
    "response = api_client.create_run_from_job_spec(\n",
    "    job_spec_path=\"batch_prediction_pipeline_job.json\",\n",
    "    pipeline_root=pipeline_root,\n",
    "    parameter_values=pipeline_params,\n",
    "    enable_caching=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "269ee125",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "name": "tf2-gpu.2-5.m74",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-5:m74"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
